[TF] Remove unbroadcast(to:) and improve derivative performance. (#24408)

rxwei wants to merge 2 commits into tensorflow

Conversation
@pschuh I don't have time to look into it in the next couple of days, but it'd be great if you could take a look!
In the pullback for operators that broadcast, use `Raw.broadcastGradientArgs(s0:s1:)` to compute reduction indices instead of using the inefficient `unbroadcast(to:)`. `unbroadcast(to:)` was introduced only for defining derivatives for broadcasting operators and has no practical use, so now we remove it.

Operators affected (a sketch of the new pullback pattern follows the list):

- `Tensor.+(_:_:)`
- `Tensor.-(_:_:)`
- `Tensor.*(_:_:)`
- `Tensor./(_:_:)`
- `min(_:_:)`
- `max(_:_:)`
- `pow(_:_:)`
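As an illustration of the pattern, here is a minimal sketch against the S4TF-era API, not the exact diff in this PR; it mirrors the `Raw.broadcastGradientArgs(s0:s1:)`, `sum(squeezingAxes:)`, and `reshaped(toShape:)` calls visible in the diff excerpts below:

```swift
extension Tensor where Scalar: TensorFlowFloatingPoint {
  // Sketch of a pullback for a broadcasting binary operator. A single
  // BroadcastGradientArgs op computes the axes each operand was broadcast
  // along, replacing the many ops previously issued by `unbroadcast(to:)`.
  static func _vjpAdd(
    lhs: Tensor, rhs: Tensor
  ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
    return (lhs + rhs, { v in
      let (lhsAxes, rhsAxes) =
        Raw.broadcastGradientArgs(s0: lhs.shapeTensor, s1: rhs.shapeTensor)
      // Sum the incoming gradient along the broadcast axes, then reshape to
      // restore any size-1 dimensions (see the discussion further down).
      return (v.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhs.shapeTensor),
              v.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhs.shapeTensor))
    })
  }
}
```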
Force-pushed from 4511398 to 0902283
```swift
static func _vjpSubtract(
  lhs: Tensor, rhs: Scalar
) -> (Tensor, (Tensor) -> (Tensor, Scalar)) {
  return (lhs - rhs, { v in (v, 0 - v.sum().scalarized()) })
```
This is some legacy code introduced in the early days, when a `where` clause on `@differentiable` was not supported. That limitation has since been fixed, so this can be cleaned up.
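For context, this is roughly what the `where`-clause formulation looks like in S4TF ops of this era (a hedged sketch, not this PR's diff; the constraint and body are illustrative):

```swift
extension Tensor where Scalar: Numeric {
  // With `where` clauses supported on @differentiable, the derivative
  // registration can constrain Scalar directly, instead of the VJP
  // resorting to tricks like `0 - v.sum().scalarized()` above.
  @differentiable(vjp: _vjpSubtract(lhs:rhs:) where Scalar: TensorFlowFloatingPoint)
  static func - (lhs: Tensor, rhs: Scalar) -> Tensor {
    return lhs - Tensor(rhs)
  }
}
```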
```swift
          -v.unbroadcast(toShape: rhsShape))
    let (lhsAxes, rhsAxes) =
      Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
    return (v.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
```
I haven't looked too much, but I suspect that this extra reshape is not necessary. The `lhsAxes` should be sufficient to recover the original shape.
That's what I tried initially (more specifically, `sum(alongAxes:)`) but it didn't work.
@pschuh @rxwei the reshape is needed for handling dimensions with size 1. For example, say you do:

```swift
// x has shape [B, 5]
// y has shape [5]
// result has shape [B, 5]
let result = x + y
```

In this case, the broadcast indices for the gradient wrt `y` will be `[0]`, and so we'll do something like:

```swift
let yGrad = seed.sum(alongAxes: [0]) // No reshape needed.
```

Now, let `y` have shape `[1, 5]`, which still broadcasts correctly for this example. The broadcast indices will now also be the same for the gradient (i.e., `[0]`). However, we need to do the reshape to recover the dimensions of size 1. Thus, the gradient needs to be computed as:

```swift
let yGrad = seed.sum(alongAxes: [0]).reshaped(to: y.shape)
```

Having said that, I have a working implementation of these changes that I had made as part of a future swift-apis PR. I'll try to open a PR here for this ASAP, but I haven't gotten the chance yet because I'm traveling to ICLR this week.
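To make the size-1 case concrete, here is a hedged sketch using `sum(squeezingAxes:)` (the variant the diff above uses); the shapes follow the example, and `B = 4` is an arbitrary choice:

```swift
let seed = Tensor<Float>(ones: [4, 5])    // Gradient seed for `result` of shape [B = 4, 5].
let y = Tensor<Float>(zeros: [1, 5])      // `y` broadcasts along axis 0.
let reduced = seed.sum(squeezingAxes: 0)  // Shape [5]: the size-1 axis is squeezed away.
let yGrad = reduced.reshaped(to: [1, 5])  // The reshape restores `y`'s shape.
```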
To make sure we don't regress in the future, could you add a quick test case in your other PR to swift-apis? :-)
Yeap, I will go ahead and add that. Given that the merge already happened, is it ok to make this change after we move stdlib to swift-apis? I'll update the two PRs doing the move tonight.
…rmance. (#24907)

The inefficiency of `unbroadcast(toShape:)`, `unbroadcast(to:)`, and `unbroadcast(like:)` has caused significant performance problems during model training, because they dispatch many TensorFlow operations just to compute reduction axes. We were forced to implement them this way in the early GPE era, when neither send/receive nor per-op dispatch was available.

This PR reimplements the unbroadcast operations in terms of host-side logic that computes the axes to reduce along. This significantly reduces TensorFlow operation dispatch overhead. The base implementation changed from `broadcast(toShape:)` to `broadcast(to:)`. With the new implementation, differentiating broadcasting operators is 37% faster (see the simple test script [here](https://gist.github.com/rxwei/e1488cac5379ba2bc3aff7490e18158f)).

Note:
- Since we now rely on the TensorFlow runtime less, more precondition checks and assertions are added to the newly implemented `unbroadcast(to:)` method.
- The part of #24408 that uses `Raw.broadcastGradientArgs(s0:s1:)` is still necessary for broadcasting binary operations to become faster.

TODO:
- Change the `unbroadcast(toShape:)` tests added by #24899 to use `unbroadcast(to:)`, since `unbroadcast(to:)` is now the base implementation.
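For illustration, a hedged sketch of how host-side axis computation for `unbroadcast(to:)` can look (assumed names and checks; not the actual #24907 implementation):

```swift
extension Tensor where Scalar: Numeric {
  // Computes the axes to reduce along on the host from static shapes,
  // instead of issuing TensorFlow ops for the axis calculation, then
  // performs a single sum-and-reshape to undo broadcasting.
  func unbroadcast(to shape: TensorShape) -> Tensor {
    let dimensionGap = rank - shape.rank
    precondition(dimensionGap >= 0,
                 "The destination shape cannot have a higher rank.")
    var reductionAxes: [Int] = Array(0..<dimensionGap)  // Leading padded axes.
    for (offset, destinationDim) in shape.dimensions.enumerated() {
      let sourceDim = self.shape[dimensionGap + offset]
      if destinationDim == 1 && sourceDim != 1 {
        reductionAxes.append(dimensionGap + offset)      // Broadcast size-1 axis.
      } else {
        precondition(destinationDim == sourceDim,
                     "Shapes must be broadcast-compatible.")
      }
    }
    return sum(squeezingAxes: reductionAxes).reshaped(to: shape)
  }
}
```

Because the axes come from static shapes on the host, only the final `sum` and `reshaped` dispatch TensorFlow ops, which is where the reduced dispatch overhead comes from.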
@rxwei Done in tensorflow/swift-apis#142.
Re-implementation of swiftlang/swift#24408. In the pullback for operators that broadcast, use `Raw.broadcastGradientArgs(s0:s1:)` to compute reduction indices instead of using the inefficient `unbroadcast(to:)`.
TODO before merging:
- Currently there's a failure on `+` (see the test being commented out). Figure out what's wrong and fix it.